{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Lf7huAiYp-An"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "YHz2D-oIqBWa"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "x44FFES-r6y0"
      },
      "source": [
        "# Federated Learning for Text Generation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "iPFgLeZIsZ3Q"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/federated/tutorials/federated_learning_for_text_generation\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/federated/blob/v0.8.0/docs/tutorials/federated_learning_for_text_generation.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/federated/blob/v0.8.0/docs/tutorials/federated_learning_for_text_generation.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "KbNz2tuvsAFB"
      },
      "source": [
        "**NOTE**: This colab has been verified to work with the `0.7.0` version of the `tensorflow_federated` pip package, but the Tensorflow Federated project is still in pre-release development and may not work on `master`.\n",
        "\n",
        "This tutorial builds on the concepts in the [Federated Learning for Image Classification](federated_learning_for_image_classification.md) tutorial, and demonstrates several other useful approaches for federated learning.\n",
        "\n",
        "In particular, we load a previously trained Keras model, and refine it using  federated training on a (simulated) decentralized dataset. This is practically important for several reasons . The ability to use serialized models makes it easy to mix federated learning with other ML approaches. Further, this allows use of an increasing range of pre-trained models --- for example, training language models from scratch is rarely necessary, as numerous pre-trained models are now widely available (see, e.g., [TF Hub](https://www.tensorflow.org/hub)). Instead, it makes more sense to start from a pre-trained model, and refine it using Federated Learning, adapting to the particular characteristics of the decentralized data for a particular application.\n",
        "\n",
        "For this tutorial, we start with a RNN that generates ASCII characters, and refine it via federated learning. We also show how the final weights can be fed back to the original Keras model, allowing easy evaluation and text generation using standard tools."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "9LcC1AwjoqfR"
      },
      "outputs": [],
      "source": [
        "# NOTE: If you are running a Jupyter notebook, and installing a locally built\n",
        "# pip package, you may need to edit the following to point to the '.whl' file\n",
        "# on your local filesystem.\n",
        "\n",
        "!pip install --quiet tensorflow_federated\n",
        "!pip install --quiet tf-nightly\n",
        "\n",
        "# NOTE: Jupyter requires a patch to asyncio.\n",
        "!pip install --upgrade nest_asyncio\n",
        "import nest_asyncio\n",
        "nest_asyncio.apply()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "height": 87
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 18860,
          "status": "ok",
          "timestamp": 1568007329213,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "ZjDQysatrc2S",
        "outputId": "30c8ab46-8e0d-4d59-b46a-a947de1357eb"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "'Hello, World!'"
            ]
          },
          "execution_count": 3,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "from __future__ import absolute_import, division, print_function\n",
        "\n",
        "import collections\n",
        "import functools\n",
        "import os\n",
        "import six\n",
        "import time\n",
        "\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_federated as tff\n",
        "\n",
        "tf.compat.v1.enable_v2_behavior()\n",
        "\n",
        "np.random.seed(0)\n",
        "\n",
        "# Test the TFF is working:\n",
        "tff.federated_computation(lambda: 'Hello, World!')()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lyICXwVAxvW9"
      },
      "source": [
        "# Load a pre-trained model\n",
        "\n",
        "We load a model that was pre-trained following the TensorFlow tutorial\n",
        "[Text generation using a RNN with eager execution](https://www.tensorflow.org/tutorials/sequences/text_generation). However,\n",
        "rather than training on [The Complete Works of Shakespeare](http://www.gutenberg.org/files/100/100-0.txt), we pre-trained the model on the text from the Charles Dickens'\n",
        "    [A Tale of Two Cities](http://www.ibiblio.org/pub/docs/books/gutenberg/9/98/98.txt)\n",
        "    and\n",
        "    [A Christmas Carol](http://www.ibiblio.org/pub/docs/books/gutenberg/4/46/46.txt).\n",
        " \n",
        "Other than expanding the vocabularly, we didn't modify the original tutorial, so this initial model isn't state-of-the-art, but it produces reasonable predictions and is sufficient for our tutorial purposes. The final model was saved with `tf.keras.models.save_model(include_optimizer=False)`.\n",
        "   \n",
        " We will use federated learning to fine-tune this model for Shakespeare in this tutorial, using a federated version of the data provided by TFF.\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "XgF8e2Ksyq1F"
      },
      "source": [
        "## Generate the vocab lookup tables"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "IlCgQBRVymwR"
      },
      "outputs": [],
      "source": [
        "# A fixed vocabularly of ASCII chars that occur in the works of Shakespeare and Dickens:\n",
        "vocab = list('dhlptx@DHLPTX $(,048cgkoswCGKOSW[_#\\'/37;?bfjnrvzBFJNRVZ\"\u0026*.26:\\naeimquyAEIMQUY]!%)-159\\r')\n",
        "\n",
        "# Creating a mapping from unique characters to indices\n",
        "char2idx = {u:i for i, u in enumerate(vocab)}\n",
        "idx2char = np.array(vocab)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "2EH6MFRdzAwd"
      },
      "source": [
        "## Load the pre-trained model and generate some text"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "iIK674SrtCTm"
      },
      "outputs": [],
      "source": [
        "def load_model(batch_size):\n",
        "  urls = {\n",
        "      1: 'https://storage.googleapis.com/tff-models-public/dickens_rnn.batch1.kerasmodel',\n",
        "      8: 'https://storage.googleapis.com/tff-models-public/dickens_rnn.batch8.kerasmodel'}\n",
        "  assert batch_size in urls, 'batch_size must be in ' + str(urls.keys())\n",
        "  url = urls[batch_size]\n",
        "  local_file = tf.keras.utils.get_file(os.path.basename(url), origin=url)  \n",
        "  return tf.keras.models.load_model(local_file, compile=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "WvuwZBX5Ogfd"
      },
      "outputs": [],
      "source": [
        "def generate_text(model, start_string):\n",
        "  # From https://www.tensorflow.org/tutorials/sequences/text_generation\n",
        "  num_generate = 200\n",
        "  input_eval = [char2idx[s] for s in start_string]\n",
        "  input_eval = tf.expand_dims(input_eval, 0)\n",
        "  text_generated = []\n",
        "  temperature = 1.0\n",
        "\n",
        "  model.reset_states()\n",
        "  for i in range(num_generate):\n",
        "    predictions = model(input_eval)\n",
        "    predictions = tf.squeeze(predictions, 0)\n",
        "    predictions = predictions / temperature\n",
        "    predicted_id = tf.multinomial(predictions, num_samples=1)[-1,0].numpy()      \n",
        "    input_eval = tf.expand_dims([predicted_id], 0)      \n",
        "    text_generated.append(idx2char[predicted_id])\n",
        "\n",
        "  return (start_string + ''.join(text_generated))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "colab": {
          "height": 101
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 3056,
          "status": "ok",
          "timestamp": 1568007332571,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "MGAdStJ5wDPV",
        "outputId": "2772f785-c377-4226-e90a-0c704f0a8ec2"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "What of TensorFlow Federated, you ask? Say what kind then brought no\r\n",
            "confidence in any occupied air; \"I was writting out at Paris.\r\n",
            "\r\n",
            "\"I arrived at Mr. Lorry he had compended till\r\n",
            "your hand miss, to such paper to death himself against th\n"
          ]
        }
      ],
      "source": [
        "# Text generation requires a batch_size=1 model.\n",
        "keras_model_batch1 = load_model(batch_size=1)\n",
        "print(generate_text(keras_model_batch1, 'What of TensorFlow Federated, you ask? '))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "kKMUn-TlgxuP"
      },
      "source": [
        "# Load and Preprocess the Federated Shakespeare Data\n",
        "\n",
        "The `tff.simulation.datasets` package provides a variety of datasets that are split into \"clients\", where each client corresponds to a dataset on a particular device that might participate in federated learning.\n",
        "\n",
        "These datasets provide realistic non-IID data distributions that replicate in simulation the challenges of training on real decentralized data. Some of the pre-processing of this data was done using tools from the [Leaf project](https://arxiv.org/abs/1812.01097) ([github](https://github.com/TalwalkarLab/leaf))."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "di3nStTDg0qc"
      },
      "outputs": [],
      "source": [
        "train_data, test_data = tff.simulation.datasets.shakespeare.load_data()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "_iiY65Vv4QNK"
      },
      "source": [
        "The datasets provided by `shakespeare.load_data()` consist of a sequence of\n",
        "string `Tensors`, one for each line spoken by a particular character in a\n",
        "Shakespeare play. The client keys consist of the name of the play joined with\n",
        "the name of the character, so for example `MUCH_ADO_ABOUT_NOTHING_OTHELLO` corresponds to the lines for the character Othello in the play *Much Ado About Nothing*. Note that in a real federated learning scenario\n",
        "clients are never identified or tracked by ids, but for simulation it is useful\n",
        "to work with keyed datasets.\n",
        "\n",
        "Here, for example, we can look at some data from King Lear:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "colab": {
          "height": 50
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 3148,
          "status": "ok",
          "timestamp": 1568007346407,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "FEKiy1ntmmnk",
        "outputId": "e4d853b2-c62d-49e7-d825-b49c23e21137"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "tf.Tensor(b'', shape=(), dtype=string)\n",
            "tf.Tensor(b'What?', shape=(), dtype=string)\n"
          ]
        }
      ],
      "source": [
        "# Here the play is \"The Tragedy of King Lear\" and the character is \"King\".\n",
        "raw_example_dataset = train_data.create_tf_dataset_for_client(\n",
        "    'THE_TRAGEDY_OF_KING_LEAR_KING')\n",
        "# To allow for future extensions, each entry x\n",
        "# is an OrderedDict with a single key 'snippets' which contains the text.\n",
        "for x in raw_example_dataset.take(2):\n",
        "  print(x['snippets'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "kUnbI5Hp4sXg"
      },
      "source": [
        "We now use `tf.data.Dataset` transformations to prepare this data for training the char RNN loaded above. \n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "9kDkmGe-7No7"
      },
      "outputs": [],
      "source": [
        "# Input pre-processing parameters\n",
        "SEQ_LENGTH = 100\n",
        "BATCH_SIZE = 8\n",
        "BUFFER_SIZE = 10000  # For dataset shuffling"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "W95Of6Bwsrfc"
      },
      "outputs": [],
      "source": [
        "# Using a namedtuple with keys x and y as the output type of the\n",
        "# dataset keeps both TFF and Keras happy:\n",
        "BatchType = collections.namedtuple('BatchType', ['x', 'y'])\n",
        "\n",
        "# Construct a lookup table to map string chars to indexes,\n",
        "# using the vocab loaded above:\n",
        "table = tf.contrib.lookup.index_table_from_tensor(\n",
        "    mapping=vocab,\n",
        "    num_oov_buckets=0,\n",
        "    default_value=0)  \n",
        "  \n",
        "def to_ids(x):\n",
        "  s = tf.reshape(x['snippets'], shape=[1])\n",
        "  chars = tf.string_split(s, delimiter='').values\n",
        "  ids = table.lookup(chars)\n",
        "  return ids  \n",
        "\n",
        "def split_input_target(chunk):\n",
        "  input_text = tf.map_fn(lambda x: x[:-1], chunk)\n",
        "  target_text = tf.map_fn(lambda x: x[1:], chunk)\n",
        "  return BatchType(input_text, target_text)\n",
        "\n",
        "  \n",
        "def preprocess(dataset):  \n",
        "  return (\n",
        "      # Map ASCII chars to int64 indexes using the vocab\n",
        "      dataset.map(to_ids)\n",
        "      # Split into individual chars\n",
        "      .apply(tf.data.experimental.unbatch())\n",
        "      # Form example sequences of SEQ_LENGTH +1\n",
        "      .batch(SEQ_LENGTH + 1,  drop_remainder=True)\n",
        "      # Shuffle and form minibatches\n",
        "      .shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)\n",
        "      # And finally split into (input, target) tuples,\n",
        "      # each of length SEQ_LENGTH.\n",
        "      .map(split_input_target))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Jw98HnKmEhuh"
      },
      "source": [
        "Note that in the formation of the original sequences and in the formation of\n",
        "batches above, we use `drop_remainder=True` for simplicity. This means that any\n",
        "characters (clients) that don't have at least `(SEQ_LENGTH + 1) * BATCH_SIZE`\n",
        "chars of text will have empty datasets. A typical approach to address this would\n",
        "be to pad the batches with a special token, and then mask the loss to not take\n",
        "the padding tokens into account.\n",
        "\n",
        "This would complicate the example somewhat, so for this tutorial we only use full batches, as in the\n",
        "[standard tutorial](https://www.tensorflow.org/tutorials/sequences/text_generation).\n",
        "However, in the federated setting this issue is more significant, because many\n",
        "users might have small datasets.\n",
        "\n",
        "Now we can preprocess our `raw_example_dataset`, and check the types:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "colab": {
          "height": 34
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 330,
          "status": "ok",
          "timestamp": 1568007346947,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "7rTal7bksWwc",
        "outputId": "36e2c0fe-8994-49a9-ae9e-b82ff4e49c06"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "BatchType(x=tf.int64, y=tf.int64) BatchType(x=TensorShape([8, 100]), y=TensorShape([8, 100]))\n"
          ]
        }
      ],
      "source": [
        "example_dataset = preprocess(raw_example_dataset)\n",
        "print(tf.compat.v1.data.get_output_types(example_dataset), tf.compat.v1.data.get_output_shapes(example_dataset))  "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ePT8Oawm8SRP"
      },
      "source": [
        "# Compile the model and test on the preprocessed data\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "vEgDsz-48cAq"
      },
      "source": [
        "We loaded an uncompiled keras model, but in order to run `keras_model.evaluate`, we need to compile it with a loss and metrics. We will also compile in an optimizer, which will be used as the on-device optimizer in Federated Learning."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "RsuVZ5KMWnn8"
      },
      "source": [
        "The original tutorial didn't have char-level accuracy (the fraction\n",
        "of predictions where the highest probability was put on the correct\n",
        "next char). This is a useful metric, so we add it.\n",
        "However, we need to define a new metric class for this because \n",
        "our predictions have rank 3 (a vector of logits for each of the \n",
        "`BATCH_SIZE * SEQ_LENGTH` predictions), and `SparseCategoricalAccuracy`\n",
        "expects only rank 2 predictions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "gOUiDBvmWlM9"
      },
      "outputs": [],
      "source": [
        "class FlattenedCategoricalAccuracy(tf.keras.metrics.SparseCategoricalAccuracy):\n",
        "\n",
        "  def __init__(self, name='accuracy', dtype=None):\n",
        "    super(FlattenedCategoricalAccuracy, self).__init__(name, dtype=dtype)\n",
        "\n",
        "  def update_state(self, y_true, y_pred, sample_weight=None):\n",
        "    y_true = tf.reshape(y_true, [-1, 1])\n",
        "    y_pred = tf.reshape(y_pred, [-1, len(vocab), 1])\n",
        "    return super(FlattenedCategoricalAccuracy, self).update_state(\n",
        "        y_true, y_pred, sample_weight)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "QH6vXHHeWkYN"
      },
      "outputs": [],
      "source": [
        "def compile(keras_model):\n",
        "  keras_model.compile(\n",
        "      optimizer=tf.keras.optimizers.SGD(lr=0.5),\n",
        "      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "      metrics=[FlattenedCategoricalAccuracy()])\n",
        "  return keras_model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "U2X9eFgt94PM"
      },
      "source": [
        "Now we can compile a model, and evaluate it on our `example_dataset`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "colab": {
          "height": 118
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 4614,
          "status": "ok",
          "timestamp": 1568007351743,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "c3Xd-52-9zGa",
        "outputId": "3f1adf96-9542-4392-ac29-f1dfb62b5cd5"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Evaluating on an example Shakespeare character:\n",
            "1/1 [==============================] - 4s 4s/step - loss: 3.3485 - accuracy: 0.3900\n",
            "Expected accuracy for random guessing: 0.012\n",
            "Evaluating on completely random data:\n",
            "1/1 [==============================] - 0s 210ms/step - loss: 11.4624 - accuracy: 0.0113\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "[11.462431907653809, 0.01125]"
            ]
          },
          "execution_count": 15,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "BATCH_SIZE = 8  # The training and eval batch size for the rest of this tutorial.\n",
        "keras_model = load_model(batch_size=BATCH_SIZE)\n",
        "\n",
        "compile(keras_model)\n",
        "\n",
        "# Confirm that loss is much lower on Shakespeare than on random data\n",
        "print('Evaluating on an example Shakespeare character:')\n",
        "keras_model.evaluate(example_dataset.take(1))\n",
        "\n",
        "# As a sanity check, we can construct some completely random data, where we expect\n",
        "# the accuracy to be essentially random:\n",
        "random_indexes = np.random.randint(\n",
        "    low=0, high=len(vocab), size=1 * BATCH_SIZE * (SEQ_LENGTH + 1))\n",
        "data = {\n",
        "    'snippets':\n",
        "        tf.constant(''.join(np.array(vocab)[random_indexes]), shape=[1, 1])\n",
        "}\n",
        "random_dataset = preprocess(tf.data.Dataset.from_tensor_slices(data))\n",
        "print('Expected accuracy for random guessing: {:.3f}'.format(1.0 / len(vocab)))\n",
        "print('Evaluating on completely random data:')\n",
        "keras_model.evaluate(random_dataset, steps=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lH0WzL5L8Lm4"
      },
      "source": [
        "# Fine-tune the model with Federated Learning"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "NCao4M3L_tsA"
      },
      "source": [
        "TFF serializes all TensorFlow computations so they can potentially be run in a\n",
        "non-Python environment (even though at the moment, only a simulation runtime implemented in Python is available). Even though we are running in eager mode, (TF 2.0), currently TFF serializes TensorFlow computations by constructing the\n",
        "necessary ops inside the context of a \"`with tf.Graph.as_default()`\" statement.\n",
        "Thus, we need to provide a function that TFF can use to introduce our model into\n",
        "a graph it controls. We do this as follows:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "5KadIvFp7m6y"
      },
      "outputs": [],
      "source": [
        "# Clone the keras_model inside `create_tff_model()`, which TFF will\n",
        "# call to produce a new copy of the model inside the graph that it will serialize.\n",
        "def create_tff_model():\n",
        "  # TFF uses a `dummy_batch` so it knows the types and shapes\n",
        "  # that your model expects.\n",
        "  x = tf.constant(np.random.randint(1, len(vocab), size=[BATCH_SIZE, SEQ_LENGTH]))\n",
        "  dummy_batch = collections.OrderedDict([('x', x), ('y', x)]) \n",
        "  keras_model_clone = compile(tf.keras.models.clone_model(keras_model))\n",
        "  return tff.learning.from_compiled_keras_model(\n",
        "      keras_model_clone, dummy_batch=dummy_batch)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ZJF_yhJxAi2l"
      },
      "source": [
        "Now we are ready to construct a Federated Averaging iterative process, which we will use to improve the model (for details on the Federated Averaging algorithm, see the paper [Communication-Efficient Learning of Deep Networks from Decentralized Data](https://arxiv.org/abs/1602.05629)).\n",
        "\n",
        "We use a compiled Keras model to perform standard (non-federated) evaluation after each round of federated training. This is useful for research purposes when doing simulated federated learning and there is a  standard test dataset. \n",
        "\n",
        "In a realistic production setting this same technique might be used to take models trained with federated learning and evaluate them on a centralized benchmark dataset for testing or quality assurance purposes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "my3PW3qhAMDA"
      },
      "outputs": [],
      "source": [
        "# This command builds all the TensorFlow graphs and serializes them: \n",
        "fed_avg = tff.learning.build_federated_averaging_process(model_fn=create_tff_model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "qVOkzs9C9kmv"
      },
      "source": [
        "Here is the simplest possible loop, where we run federated averaging for one round on a single client on a single batch:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "colab": {
          "height": 34
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 18012,
          "status": "ok",
          "timestamp": 1568007377596,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "lrjUrkjq9jYk",
        "outputId": "ba6538f2-57f9-4a50-bb28-03dc527c84f7"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u003caccuracy=0.013749999925494194,loss=4.454826354980469\u003e\n"
          ]
        }
      ],
      "source": [
        "state = fed_avg.initialize()\n",
        "state, metrics = fed_avg.next(state, [example_dataset.take(1)])\n",
        "print(metrics)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "o2CjvVg0FZpS"
      },
      "source": [
        "Now let's write a slightly more interesting training and evaluation loop.\n",
        "\n",
        "So that this simulation still runs relatively quickly,  we train on the same three clients each round, only considering two minibatches for each.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "wE386-rbMCve"
      },
      "outputs": [],
      "source": [
        "def data(client, source=train_data):\n",
        "  return preprocess(\n",
        "      source.create_tf_dataset_for_client(client)).take(2)\n",
        "\n",
        "clients = ['ALL_S_WELL_THAT_ENDS_WELL_CELIA',\n",
        "           'MUCH_ADO_ABOUT_NOTHING_OTHELLO',\n",
        "           'THE_TRAGEDY_OF_KING_LEAR_KING']\n",
        "\n",
        "train_datasets = [data(client) for client in clients]\n",
        "\n",
        "# We concatenate the test datasets for evaluation with Keras.\n",
        "test_dataset = functools.reduce(\n",
        "    lambda d1, d2: d1.concatenate(d2),\n",
        "    [data(client, test_data) for client in clients])\n",
        "\n",
        "# NOTE: If the statement below fails, it means that you are\n",
        "# using an older version of TFF without the high-performance\n",
        "# executor stack. Call `tff.framework.set_default_executor()`\n",
        "# instead to use the default reference runtime.\n",
        "if six.PY3:\n",
        "  tff.framework.set_default_executor(tff.framework.create_local_executor())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "cU3FuY00MOoX"
      },
      "source": [
        "The initial state of the model produced by `fed_avg.initialize()` is based\n",
        "on the random initializers for the Keras model, not the weights that were loaded,\n",
        "since `clone_model()` does not clone the weights. To start training\n",
        "from a pre-trained model, we set the model weights in the server state\n",
        "directly from the loaded model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {
        "colab": {
          "height": 202
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 41479,
          "status": "ok",
          "timestamp": 1568007419742,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "vm_-PU8OFXpY",
        "outputId": "06d63b19-9b5e-4cc8-ff17-6f712ba1c4da"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Evaluating before training round 0\n",
            "2/2 [==============================] - 1s 452ms/step - loss: 3.2266 - accuracy: 0.4081\n",
            "Training metrics:  \u003caccuracy=0.4099999964237213,loss=3.2493884563446045\u003e\n",
            "Evaluating before training round 1\n",
            "2/2 [==============================] - 1s 442ms/step - loss: 2.9462 - accuracy: 0.4331\n",
            "Training metrics:  \u003caccuracy=0.42916667461395264,loss=2.899351119995117\u003e\n",
            "Evaluating before training round 2\n",
            "2/2 [==============================] - 1s 449ms/step - loss: 2.7754 - accuracy: 0.4500\n",
            "Training metrics:  \u003caccuracy=0.47083333134651184,loss=2.5915045738220215\u003e\n",
            "Evaluating before training round 4\n",
            "2/2 [==============================] - 1s 435ms/step - loss: 2.5748 - accuracy: 0.4837\n"
          ]
        }
      ],
      "source": [
        "NUM_ROUNDS = 3\n",
        "\n",
        "# The state of the FL server, containing the model and optimization state.\n",
        "state = fed_avg.initialize()\n",
        "\n",
        "state = tff.learning.state_with_new_model_weights(\n",
        "    state,\n",
        "    trainable_weights=[v.numpy() for v in keras_model.trainable_weights],\n",
        "    non_trainable_weights=[\n",
        "        v.numpy() for v in keras_model.non_trainable_weights\n",
        "    ])\n",
        "\n",
        "\n",
        "def keras_evaluate(state, round_num):\n",
        "  tff.learning.assign_weights_to_keras_model(keras_model, state.model)\n",
        "  print('Evaluating before training round', round_num)\n",
        "  keras_model.evaluate(example_dataset, steps=2)\n",
        "\n",
        "\n",
        "for round_num in range(NUM_ROUNDS):\n",
        "  keras_evaluate(state, round_num)\n",
        "  # N.B. The TFF runtime is currently fairly slow,\n",
        "  # expect this to get significantly faster in future releases.\n",
        "  state, metrics = fed_avg.next(state, train_datasets)\n",
        "  print('Training metrics: ', metrics)\n",
        "\n",
        "keras_evaluate(state, NUM_ROUNDS + 1)"
      ]
    },
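    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The loop above trains on the same three clients in every round. As a minimal sketch of a more realistic schedule, each round could instead draw a fresh random subset of client IDs. This is an illustration rather than part of TFF: the helper `sample_clients` and the stand-in `client_ids` list are hypothetical, and in this notebook the real IDs would presumably come from `train_data.client_ids`, with each sampled ID passed through `data(client)` before calling `fed_avg.next`.\n",
        "\n",
        "```python\n",
        "import random\n",
        "\n",
        "\n",
        "def sample_clients(client_ids, num_clients, seed=None):\n",
        "  # Hypothetical helper: returns `num_clients` distinct IDs,\n",
        "  # drawn uniformly at random without replacement.\n",
        "  rng = random.Random(seed)\n",
        "  return rng.sample(list(client_ids), num_clients)\n",
        "\n",
        "\n",
        "# Stand-in IDs for illustration only (an assumption, not real data).\n",
        "client_ids = ['client_%d' % i for i in range(10)]\n",
        "\n",
        "for round_num in range(3):\n",
        "  sampled = sample_clients(client_ids, num_clients=3, seed=round_num)\n",
        "  print('Round', round_num, 'clients:', sampled)\n",
        "  # In the notebook, the per-round datasets would then be built as:\n",
        "  #   train_datasets = [data(client) for client in sampled]\n",
        "```\n",
        "\n",
        "Seeding with the round number only keeps the illustration reproducible; omit the seed to get independent draws on every run."
      ]
    },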
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "SoshvcHhXVa6"
      },
      "source": [
        "With the default changes, we haven't done enough training to make a big difference, but if you train longer on more Shakespeare data, you should see a difference in the style of the text generated with the updated model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "metadata": {
        "colab": {
          "height": 87
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 2819,
          "status": "ok",
          "timestamp": 1568007422587,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "NTUig7QmXavy",
        "outputId": "d8ea9c19-6043-4619-9da6-72d687de3600"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "What of TensorFlow Federated, you ask? Saying Doctor Manette's house but deny.\r\n",
            "\r\n",
            "                                                                                                                                                             \n"
          ]
        }
      ],
      "source": [
        "keras_model_batch1.set_weights([v.numpy() for v in keras_model.weights])\n",
        "# Text generation requires batch_size=1\n",
        "print(generate_text(keras_model_batch1, 'What of TensorFlow Federated, you ask? '))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "4DA1Fkf5mN0s"
      },
      "source": [
        "# Suggested extensions\n",
        "\n",
        "This tutorial is just the first step! Here are some ideas for how you might try extending this notebook:\n",
        "  * Write a more realistic training loop where you sample clients to train on randomly.\n",
        "  * Use \"`.repeat(NUM_EPOCHS)`\" on the client datasets to try multiple epochs of local training (e.g., as in [McMahan et. al.](https://arxiv.org/abs/1602.05629)). See also [Federated Learning for Image Classification](federated_learning_for_image_classification.md) which does this.\n",
        "  * Change the `compile()` command to experiment with using different optimization algorithms on the client.\n",
        "  * Try the `server_optimizer` argument to `build_federated_averaging_process` to try different algorithms for applying the model updates on the server.\n",
        "  * Try the `client_weight_fn` argument to to `build_federated_averaging_process` to try different weightings of the clients. The default weights client updates by the number of examples on the client, but you can do e.g. `client_weight_fn=lambda _: tf.constant(1.0)`."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "",
        "kind": "local"
      },
      "name": "Federated Learning for Text Generation",
      "provenance": [],
      "toc_visible": true,
      "version": "0.3.2"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
